package trx

import (
	"context"

	plumeErrors "gitee.com/lipore/plume/errors"
	"gitee.com/lipore/plume/logger"
)

// Manager begins new transactions. The process-wide Manager is installed with
// Setup and is consulted whenever Tx needs a fresh transaction.
type Manager interface {
	// Start begins a new transaction for ctx and returns it as a Context.
	// A non-nil err means no transaction was started.
	Start(ctx context.Context) (tx Context, err error)
}

// Context is a transaction that is also a context.Context, so it can be handed
// to any code that expects a plain context while still exposing the
// transaction controls below.
type Context interface {
	context.Context

	// SavePoint records a savepoint identified by savepoint; its exact
	// semantics are defined by the Manager implementation.
	SavePoint(savepoint string)
	// RollbackTo rolls back to the savepoint identified by savepoint.
	RollbackTo(savepoint string)
	// Commit finishes the transaction; Tx calls it when the handler succeeds
	// and returns its error to the caller wrapped with code ErrCommit.
	Commit() error
	// Rollback aborts the transaction; Tx calls it when the handler fails or
	// Commit returns an error.
	Rollback()
}

// txManager is the Manager used to begin new transactions. It stays nil until
// Setup is called.
var txManager Manager

// Setup installs m as the Manager used to begin new transactions, replacing
// any previously installed one.
//
// NOTE(review): the write to txManager is not synchronized, so Setup is
// presumably meant to run once during initialization, before any Tx call.
func Setup(m Manager) {
	txManager = m
}

// Error codes attached (via plumeErrors.WithCode) to errors returned by this
// package.
const (
	// ErrStartTx marks an error from beginning a new transaction.
	ErrStartTx = 0
	// ErrCommit marks an error from committing a transaction.
	ErrCommit = 1
	// ErrCancel marks the error Tx returns when it gives up because a
	// transaction could not be started.
	ErrCancel = 2
)

// Tx runs h inside a transaction and finishes that transaction according to
// h's outcome.
//
// If ctx already is a transaction Context, h joins it: h runs against that
// transaction and Tx neither commits nor rolls it back, leaving that to the
// Tx call that started it. h's error is returned unchanged.
//
// Otherwise a new transaction is started through the Manager installed with
// Setup. It is committed when h returns nil. It is rolled back when h returns
// an error, when Commit fails, or when h panics; a panic is not recovered and
// keeps propagating after the rollback.
//
// Returned errors:
//   - failure to start the transaction: wrapped with code ErrCancel (logged at error level);
//   - error from h: returned unchanged (logged at warn level);
//   - failure to commit: wrapped with code ErrCommit (logged at error level).
func Tx(ctx context.Context, h func(Context) error) error {
	// Finishing a transaction we did not start would commit or abort it
	// underneath the caller that owns it, so a nested Tx only joins.
	if outer := getTx(ctx); outer != nil {
		return h(outer)
	}

	txCtx, err := startTx(ctx)
	if err != nil {
		err = plumeErrors.WithCode(err, ErrCancel, "canceling...")
		logger.Errorf("%v", err)
		return err
	}

	// Every exit path other than a successful Commit rolls back, including a
	// panic (or runtime.Goexit) inside h, so the transaction is never left open.
	committed := false
	defer func() {
		if !committed {
			txCtx.Rollback()
		}
	}()

	if err = h(txCtx); err != nil {
		logger.Warnf("%v", err)
		return err
	}
	if err = txCtx.Commit(); err != nil {
		err = plumeErrors.WithCode(err, ErrCommit, "commit transaction failed, canceling")
		logger.Errorf("%v", err)
		return err
	}
	committed = true
	return nil
}

// getTx returns ctx as a transaction Context when ctx itself implements
// Context, and nil otherwise (including for a nil ctx).
//
// Only ctx's own dynamic type is inspected: a context derived from a
// transaction (for example via context.WithTimeout) is not recognized.
func getTx(ctx context.Context) Context {
	// A failed comma-ok assertion yields the zero value, i.e. a nil Context.
	tx, _ := ctx.(Context)
	return tx
}

// noManagerError is the underlying error startTx reports when no Manager has
// been installed with Setup.
type noManagerError struct{}

func (noManagerError) Error() string {
	return "trx: no transaction manager configured, call Setup first"
}

// startTx returns the transaction to run in for ctx.
//
// When ctx already is a transaction Context it is returned as is and no new
// transaction is begun. Otherwise a new transaction is begun through the
// Manager installed by Setup. Any failure, including a missing Manager, is
// logged at warn level and returned wrapped with code ErrStartTx, with a nil
// Context.
func startTx(ctx context.Context) (Context, error) {
	if tx := getTx(ctx); tx != nil {
		return tx, nil
	}

	var (
		tx  Context
		err error
	)
	if txManager == nil {
		// Calling Start on a nil Manager would panic with a nil pointer
		// dereference; report a coded error instead.
		err = noManagerError{}
	} else {
		tx, err = txManager.Start(ctx)
	}
	if err != nil {
		err = plumeErrors.WithCode(err, ErrStartTx, "start transaction failed")
		logger.Warnf("%v", err)
		return nil, err
	}
	return tx, nil
}
